// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once
#include <sstream>
#include "paddle/common/overloaded.h"
#include "paddle/pir/include/dialect/shape/utils/dim_expr.h"
#include "paddle/pir/include/dialect/shape/utils/dim_expr_util.h"

namespace symbol {

template <typename T>
class ShapeOrData {
 public:
  explicit ShapeOrData(const std::vector<T>& shape)
      : shape_(shape), data_(std::nullopt) {}
  explicit ShapeOrData(const std::vector<T>& shape, const std::vector<T>& data)
      : shape_(shape), data_(data) {
    // Valid check
    if (shape.size() == 0) {
      PADDLE_ENFORCE_EQ(
          data.size(),
          1UL,
          common::errors::InvalidArgument(
              "When shape is 0-D, size of data should be 1, but got %d.",
              data.size()));
    } else if (shape.size() == 1) {
      PADDLE_ENFORCE_EQ(shape[0].template Has<int64_t>(),
                        true,
                        common::errors::InvalidArgument(
                            "When shape is 1-D, value of shape should be int"));
      PADDLE_ENFORCE_EQ(
          shape[0].template Get<int64_t>() == static_cast<int64_t>(data.size()),
          true,
          common::errors::InvalidArgument(
              "When shape is 1-D, size of data should be the same as "
              "value[%d] of shape, but got [%d].",
              shape[0].template Get<std::int64_t>(),
              data.size()));
    } else {
      int64_t numel = 1;
      for (const auto& expr : shape) {
        PADDLE_ENFORCE_EQ(expr.template isa<int64_t>(),
                          true,
                          ::common::errors::InvalidArgument(
                              "When data has value, the expr of shape should "
                              "be int, but got %s.",
                              ToString(expr)));
        numel *= expr.template Get<int64_t>();
      }
      PADDLE_ENFORCE_EQ(numel,
                        data.size(),
                        ::common::errors::InvalidArgument(
                            "Size of data should be the same as "
                            "product of value[%d] of shape, but got [%d].",
                            numel,
                            data.size()));
    }
  }

  ShapeOrData() = default;
  ShapeOrData(const ShapeOrData&) = default;
  ShapeOrData(ShapeOrData&&) = default;
  ShapeOrData& operator=(const ShapeOrData&) = default;
  ShapeOrData& operator=(ShapeOrData&&) = default;

  // Tensor's real shape
  const std::vector<T>& shape() const { return shape_; }
  // Specific for Tensor generated by shape-relevant ops
  const std::optional<std::vector<T>>& data() const { return data_; }
  void SetData(const std::vector<T>& data) { data_ = data; }

  bool operator==(const ShapeOrData<T>& other) const {
    if (data_.has_value() && !other.data_.has_value()) return false;
    if (!data_.has_value() && other.data_.has_value()) return false;
    if (shape_.size() != other.shape_.size()) return false;

    if (data_.has_value() && other.data_.has_value()) {
      if (data_.value().size() != other.data_.value().size()) return false;

      for (size_t i = 0; i < data_.value().size(); ++i) {
        DimExpr dim0 = symbol::SimplifyDimExpr(data_.value()[i]);
        DimExpr dim1 = symbol::SimplifyDimExpr(other.data_.value()[i]);
        if (dim0 != dim1) return false;
      }
    }

    for (size_t i = 0; i < shape_.size(); ++i) {
      DimExpr dim0 = symbol::SimplifyDimExpr(shape_[i]);
      DimExpr dim1 = symbol::SimplifyDimExpr(other.shape_[i]);
      if (dim0 != dim1) return false;
    }

    return true;
  }

  bool operator!=(const ShapeOrData<T>& other) const {
    return !(*this == other);
  }

 private:
  std::vector<T> shape_;
  std::optional<std::vector<T>> data_;
};

// Empty alternative (std::monostate); presumably used for values that carry
// no symbolic shape information.
using NullShapeOrDataDimExpr = std::monostate;
// Symbolic shape (and optional data) of a single tensor.
using TensorShapeOrDataDimExprs = ShapeOrData<DimExpr>;
// One TensorShapeOrDataDimExprs per tensor of a tensor list, in list order.
using TensorListShapeOrDataDimExprs = std::vector<TensorShapeOrDataDimExprs>;

/* A TensorArray can append tensors dynamically. In a static graph we store
 * only the shape of one element as a hint, because we assume that all
 * elements of the TensorArray have the same rank and equal constraints on
 * specific dimensions. */
class RankedTensorArrayShapeOrDataDimExprs {
 public:
  RankedTensorArrayShapeOrDataDimExprs() = default;

  // Uses `shape` as the per-element shape hint. Parentheses (not braces) make
  // this an unambiguous vector copy rather than a candidate for
  // std::initializer_list construction.
  explicit RankedTensorArrayShapeOrDataDimExprs(
      const std::vector<DimExpr>& shape)
      : shape_hint_(shape) {}

  // Shape hint assumed to be shared by every element of the array.
  const std::vector<DimExpr>& GetShapeHint() const { return shape_hint_; }

  // Equal when both hints have the same rank and every pair of corresponding
  // dimensions is equal after symbol::SimplifyDimExpr.
  bool operator==(const RankedTensorArrayShapeOrDataDimExprs& other) const {
    if (shape_hint_.size() != other.shape_hint_.size()) return false;
    for (size_t i = 0; i < shape_hint_.size(); ++i) {
      DimExpr dim0 = symbol::SimplifyDimExpr(shape_hint_[i]);
      DimExpr dim1 = symbol::SimplifyDimExpr(other.shape_hint_[i]);
      if (dim0 != dim1) return false;
    }

    return true;
  }

  // Provided for symmetry with ShapeOrData and ShapeOrDataDimExprs.
  bool operator!=(const RankedTensorArrayShapeOrDataDimExprs& other) const {
    return !(*this == other);
  }

 private:
  std::vector<DimExpr> shape_hint_;
};

// The closed set of alternatives a ShapeOrDataDimExprs can hold. Comparing two
// such variants with operator== (as ShapeOrDataDimExprs::operator== does)
// requires every alternative to provide operator==.
using ShapeOrDataDimExprsBase =
    std::variant<NullShapeOrDataDimExpr,
                 TensorShapeOrDataDimExprs,
                 TensorListShapeOrDataDimExprs,
                 RankedTensorArrayShapeOrDataDimExprs>;

class ShapeOrDataDimExprs : public ShapeOrDataDimExprsBase {
 public:
  ShapeOrDataDimExprs() = delete;
  ShapeOrDataDimExprs(
      const TensorShapeOrDataDimExprs& tensor_dim_exprs)  // NOLINT
      : ShapeOrDataDimExprsBase(tensor_dim_exprs) {}
  ShapeOrDataDimExprs(
      const TensorListShapeOrDataDimExprs& tensor_list_dim_exprs)
      : ShapeOrDataDimExprsBase(tensor_list_dim_exprs) {}

  ShapeOrDataDimExprs(const RankedTensorArrayShapeOrDataDimExprs&
                          tensor_array_dim_exprs)  // NOLINT
      : ShapeOrDataDimExprsBase(tensor_array_dim_exprs) {}

  ShapeOrDataDimExprs(const NullShapeOrDataDimExpr& null_dim_expr)  // NOLINT
      : ShapeOrDataDimExprsBase(null_dim_expr) {}

  template <typename T>
  bool isa() const {
    return std::holds_alternative<T>(*this);
  }

  template <typename T>
  const T& dyn_cast() const {
    return std::get<T>(*this);
  }

  const ShapeOrDataDimExprsBase& variant() const {
    return static_cast<const ShapeOrDataDimExprsBase&>(*this);
  }

  DEFINE_MATCH_METHOD();

  bool operator==(const ShapeOrDataDimExprs& other) const {
    return this->variant() == other.variant();
  }

  bool operator!=(const ShapeOrDataDimExprs& other) const {
    return !(*this == other);
  }

  const std::vector<DimExpr>& shape() const {
    PADDLE_ENFORCE_EQ(std::holds_alternative<TensorShapeOrDataDimExprs>(*this),
                      true,
                      common::errors::PreconditionNotMet(
                          "Shape of ShapeOrData is not a vector, "
                          "check whether the value is a "
                          "tensor-list or not."));
    return std::get<TensorShapeOrDataDimExprs>(*this).shape();
  }

  const std::optional<std::vector<DimExpr>>& data() const {
    PADDLE_ENFORCE_EQ(
        std::holds_alternative<TensorShapeOrDataDimExprs>(*this),
        true,
        common::errors::PreconditionNotMet(
            "Data of ShapeOrData is not a vector, check whether the value is a "
            "tensor-list or not."));
    return std::get<TensorShapeOrDataDimExprs>(*this).data();
  }

  void SetData(const std::vector<DimExpr>& data) {
    PADDLE_ENFORCE_EQ(
        std::holds_alternative<TensorShapeOrDataDimExprs>(*this),
        true,
        common::errors::PreconditionNotMet(
            "Data of ShapeOrData is not a vector, check whether the value is a "
            "tensor-list or not."));

    std::get<TensorShapeOrDataDimExprs>(*this).SetData(data);
  }
};

// Presumably returns a copy of `shape_or_data` whose DimExprs are rewritten
// according to `substitution_pattern` (key -> replacement).
// NOTE(review): the exact semantics (whether sub-expressions are substituted,
// whether results are simplified) live in the implementation; confirm there.
IR_API ShapeOrDataDimExprs SubstituteShapeOrData(
    const ShapeOrDataDimExprs& shape_or_data,
    const std::unordered_map<DimExpr, DimExpr>& substitution_pattern);

// Streams a textual representation of `shape_or_data` into `stream`
// (conventionally returning `stream` for chaining; confirm in implementation).
IR_API std::ostream& operator<<(std::ostream& stream,
                                const ShapeOrDataDimExprs& shape_or_data);

}  // namespace symbol

namespace std {

template <>
struct hash<symbol::TensorShapeOrDataDimExprs> {
  std::size_t operator()(const symbol::TensorShapeOrDataDimExprs& obj) const {
    const auto hash_func = std::hash<std::vector<symbol::DimExpr>>();
    std::size_t ret = hash_func(obj.shape());
    ret = pir::detail::hash_combine(ret, obj.data().has_value());
    if (obj.data().has_value()) {
      ret = pir::detail::hash_combine(ret, hash_func(obj.data().value()));
    }
    return ret;
  }
};

// Order-sensitive hash of a tensor list: starting from a seed of 0, each
// element's hash is folded in with pir::detail::hash_combine in list order.
template <>
struct hash<symbol::TensorListShapeOrDataDimExprs> {
  std::size_t operator()(
      const symbol::TensorListShapeOrDataDimExprs& obj) const {
    const std::hash<symbol::TensorShapeOrDataDimExprs> element_hasher{};
    std::size_t seed{0};
    for (std::size_t idx = 0; idx < obj.size(); ++idx) {
      const std::size_t element_hash = element_hasher(obj[idx]);
      seed = pir::detail::hash_combine(seed, element_hash);
    }
    return seed;
  }
};

// Hash for RankedTensorArrayShapeOrDataDimExprs that agrees with its
// operator==, which compares dimensions after symbol::SimplifyDimExpr. Each
// dimension of the shape hint is simplified before hashing, so hints that
// compare equal also hash equal.
template <>
struct hash<symbol::RankedTensorArrayShapeOrDataDimExprs> {
  std::size_t operator()(
      const symbol::RankedTensorArrayShapeOrDataDimExprs& obj) const {
    const std::vector<symbol::DimExpr>& hint = obj.GetShapeHint();
    std::vector<symbol::DimExpr> simplified;
    simplified.reserve(hint.size());
    for (const auto& dim : hint) {
      simplified.push_back(symbol::SimplifyDimExpr(dim));
    }
    return std::hash<std::vector<symbol::DimExpr>>()(simplified);
  }
};

// Hashes whichever alternative is active by forwarding to that alternative's
// own std::hash specialization (std::monostate has one in the standard
// library).
template <>
struct hash<symbol::ShapeOrDataDimExprs> {
  std::size_t operator()(const symbol::ShapeOrDataDimExprs& obj) const {
    return obj.Match([](const auto& alternative) -> std::size_t {
      using AlternativeT = std::decay_t<decltype(alternative)>;
      const std::hash<AlternativeT> alternative_hasher{};
      return alternative_hasher(alternative);
    });
  }
};

}  // namespace std
